Whisper是由OpenAI开源的语言识别模型,Whisper JAX则是JAX的实现版本。主要基于? Hugging Face Transformers的Whisper实现。与OpenAI的PyTorch代码相比,Whisper JAX运行速度快了70多倍,是目前最快的Whisper实现。
JAX代码兼容CPU、GPU和TPU,并且可以作为独立的运行程序(参见Pipeline Usage)或推理端点(参见Creating an Endpoint)运行。
Whisper的Flax权重文件与JAX版本的预训练结果文件完全兼容,各个版本的信息和能力如下:
模型size | 参数数量 | 是否仅支持英文 | 多语言能力 |
---|---|---|---|
tiny | 3900万 | Y | Y |
base | 7400万 | Y | Y |
small | 2.44亿 | Y | Y |
medium | 7.69亿 | Y | Y |
large | 15.5亿 | N | Y |
large-v2 | 15.5亿 | N | Y |
官方公开的是PyTorch版本,需要先使用from_pt来将其转换成Flax版本。各个不同版本的Whisper对比结果:
Whisper发布者 | 代码框架 | 后端硬件 | 1分钟 | 10分钟 | 1个小时 |
---|---|---|---|---|---|
OpenAI | PyTorch | GPU | 13.8 | 108.3 | 1001 |
Transformers | PyTorch | GPU | 4.54 | 20.2 | 126.1 |
Whisper JAX | JAX | GPU | 1.72 | 9.38 | 75.3 |
Whisper JAX | JAX | TPU | 0.45 | 2.01 | 13.8 |
上表中的1分钟、10分钟和1个小时分别代表不同模型转换这么长时间语音所需要的推断时间,单位是秒。可以看到,如果都是用GPU的话,Whisper一个小时的语音转换只要75秒,而OpenAI官方的模型需要1001秒,也就是十几分钟!如果使用TPU,那么1个小时的语音转换只要13.8秒!不得不说,谷歌全家桶的性能非常赞!
Whisper JAX模型的GitHub开源地址: https://github.com/sanchit-gandhi/whisper-jax